/*
 * Copyright (c) 2020 - present, Inspur Genersoft Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.iec.edp.caf.session.repo;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.data.redis.core.RedisOperations;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.session.Session;

import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/**
 * Replacement for Spring Session's {@code RedisSessionExpirationPolicy}: the expiration adjustment
 * performed whenever a session's expiration is updated is executed as one Lua script instead of
 * several individual Redis commands.
 *
 * @author manwenxing01
 * @date 2023-02-24
 */
final class CafRedisSessionExpirationPolicy {
    private static final Log logger = LogFactory.getLog(CafRedisSessionExpirationPolicy.class);

    /**
     * Prefix of a session's "expires" key. The same string ({@code "expires:" + sessionId}) is also
     * stored as the member of an expiration bucket set.
     */
    private static final String EXPIREKEY_PREFIX = "expires:";

    /**
     * Prefix prepended to the numeric TTL values that are passed to the Lua script through KEYS;
     * the script strips it again before calling {@code tonumber}.
     * NOTE(review): this looks like a Redis Cluster hash tag intended to keep the pseudo-keys in
     * the same slot as the real session keys — confirm against the configured key namespace.
     */
    private static final String SCRIPT_NUMBER_KEY_PREFIX = "{caf-session}";

    /** 1-based Lua {@code string.sub} start index of the number following {@link #SCRIPT_NUMBER_KEY_PREFIX}. */
    private static final int SCRIPT_NUMBER_OFFSET = SCRIPT_NUMBER_KEY_PREFIX.length() + 1;

    /**
     * Lua script that slides a session's expiration.
     * <ul>
     *   <li>KEYS[1]: expiration bucket key for the rounded-up expiry minute</li>
     *   <li>KEYS[2]: the session's "expires:" key</li>
     *   <li>KEYS[3]: the session key built from the raw session id</li>
     *   <li>KEYS[4]: {@link #SCRIPT_NUMBER_KEY_PREFIX} + session TTL in milliseconds</li>
     *   <li>KEYS[5]: {@link #SCRIPT_NUMBER_KEY_PREFIX} + (session TTL + 5 minutes) in milliseconds</li>
     *   <li>ARGV[1]: the bucket member, {@code "expires:" + sessionId}</li>
     * </ul>
     * A TTL of 0 deletes the "expires:" key instead of refreshing it.
     */
    private static final DefaultRedisScript<Object> SESSION_UPDATED_SCRIPT = new DefaultRedisScript<>(
            "local expireMills = tonumber(string.sub(KEYS[4], " + SCRIPT_NUMBER_OFFSET + "))\n"
                    + "local fiveExpireMills = tonumber(string.sub(KEYS[5], " + SCRIPT_NUMBER_OFFSET + "))\n"
                    + "redis.call('SADD', KEYS[1], ARGV[1])\n"
                    + "redis.call('PEXPIRE', KEYS[1], fiveExpireMills)\n"
                    + "if expireMills == 0 then\n"
                    + "    redis.call('DEL', KEYS[2])\n"
                    + "else\n"
                    + "    redis.call('APPEND', KEYS[2], '')\n"
                    + "    redis.call('PEXPIRE', KEYS[2], expireMills)\n"
                    + "end\n"
                    + "redis.call('PEXPIRE', KEYS[3], fiveExpireMills)\n");

    private final RedisOperations<Object, Object> redis;
    private final Function<Long, String> lookupExpirationKey;
    private final Function<String, String> lookupSessionKey;

    CafRedisSessionExpirationPolicy(RedisOperations<Object, Object> sessionRedisOperations, Function<Long, String> lookupExpirationKey, Function<String, String> lookupSessionKey) {
        this.redis = sessionRedisOperations;
        this.lookupExpirationKey = lookupExpirationKey;
        this.lookupSessionKey = lookupSessionKey;
    }

    /**
     * Removes the session's entry from the expiration bucket for its current expiry minute.
     *
     * @param session the session being deleted
     */
    void onDelete(Session session) {
        long toExpire = roundUpToNextMinute(expiresInMillis(session));
        String expireKey = getExpirationKey(toExpire);
        this.redis.boundSetOps(expireKey).remove(EXPIREKEY_PREFIX + session.getId());
    }

    /**
     * Moves the session to the expiration bucket matching its new expiry and refreshes the TTLs of
     * its keys.
     *
     * @param originalExpirationTimeInMilli the previous expiry in epoch milliseconds, or {@code null}
     *                                      when there was none
     * @param session                       the session whose expiration changed
     */
    void onExpirationUpdated(Long originalExpirationTimeInMilli, Session session) {
        String keyToExpire = EXPIREKEY_PREFIX + session.getId();
        long toExpire = roundUpToNextMinute(expiresInMillis(session));

        // Drop the entry from the previous bucket if the expiry moved to a different minute.
        if (originalExpirationTimeInMilli != null) {
            long originalRoundedUp = roundUpToNextMinute(originalExpirationTimeInMilli);
            if (toExpire != originalRoundedUp) {
                this.redis.boundSetOps(getExpirationKey(originalRoundedUp)).remove(keyToExpire);
            }
        }

        long sessionExpireInSeconds = session.getMaxInactiveInterval().getSeconds();
        String sessionKey = getSessionKey(keyToExpire);
        String sessionHashKey = getSessionKey(session.getId());

        if (sessionExpireInSeconds < 0L) {
            // Negative max-inactive interval: make both keys persistent (no TTL).
            this.redis.boundValueOps(sessionKey).append("");
            this.redis.boundValueOps(sessionKey).persist();
            this.redis.boundHashOps(sessionHashKey).persist();
            return;
        }

        long sessionExpireInMills = TimeUnit.SECONDS.toMillis(sessionExpireInSeconds);
        long fiveMinutesAfterExpiresInMills = sessionExpireInMills + TimeUnit.MINUTES.toMillis(5L);

        /*
         * When RedisOperations.execute runs a Lua script, ARGV values go through the template's value
         * serializer and are then not readable as plain numbers inside the script, so the numeric
         * TTLs are passed in KEYS instead. keyToExpire is the exception: it must be value-serialized
         * so the stored set member matches what boundSetOps(...).remove(...) removes, so it goes in
         * ARGV. In total: 5 entries in KEYS, 1 in ARGV.
         */
        List<Object> keys = Arrays.<Object>asList(
                getExpirationKey(toExpire),                                       // expiration bucket
                sessionKey,                                                       // "expires:" key
                sessionHashKey,                                                   // session key
                SCRIPT_NUMBER_KEY_PREFIX + sessionExpireInMills,                  // session TTL (ms)
                SCRIPT_NUMBER_KEY_PREFIX + fiveMinutesAfterExpiresInMills);       // session TTL + 5 min (ms)

        this.redis.execute(SESSION_UPDATED_SCRIPT, keys, keyToExpire);
    }

    /** Returns the expiration bucket key for the given minute timestamp (epoch ms). */
    String getExpirationKey(long expires) {
        return this.lookupExpirationKey.apply(expires);
    }

    /** Returns the Redis key for the given session id (or "expires:"-prefixed id). */
    String getSessionKey(String sessionId) {
        return this.lookupSessionKey.apply(sessionId);
    }

    /**
     * Processes the expiration bucket of the previous minute: deletes the bucket and touches each
     * session key it listed.
     */
    void cleanExpiredSessions() {
        long now = System.currentTimeMillis();
        long prevMin = roundDownMinute(now);
        if (logger.isDebugEnabled()) {
            logger.debug("Cleaning up sessions expiring at " + new Date(prevMin));
        }

        String expirationKey = getExpirationKey(prevMin);
        Set<Object> sessionsToExpire = this.redis.boundSetOps(expirationKey).members();
        this.redis.delete(expirationKey);
        // members() is documented to return null when used in a pipeline/transaction.
        if (sessionsToExpire == null) {
            return;
        }

        for (Object session : sessionsToExpire) {
            touch(getSessionKey((String) session));
        }
    }

    /**
     * Accesses the key so that Redis expires it lazily if its TTL has already elapsed.
     */
    private void touch(String key) {
        this.redis.hasKey(key);
    }

    /**
     * Computes the session's expiry as last-access time plus max-inactive interval, in epoch ms.
     * The interval is kept as a {@code long} so very long intervals are not truncated.
     */
    static long expiresInMillis(Session session) {
        long maxInactiveInSeconds = session.getMaxInactiveInterval().getSeconds();
        long lastAccessedTimeInMillis = session.getLastAccessedTime().toEpochMilli();
        return lastAccessedTimeInMillis + TimeUnit.SECONDS.toMillis(maxInactiveInSeconds);
    }

    /** Adds one minute to {@code timeInMs} and clears seconds and milliseconds (default time zone). */
    static long roundUpToNextMinute(long timeInMs) {
        Calendar date = Calendar.getInstance();
        date.setTimeInMillis(timeInMs);
        date.add(Calendar.MINUTE, 1);
        date.clear(Calendar.SECOND);
        date.clear(Calendar.MILLISECOND);
        return date.getTimeInMillis();
    }

    /** Clears the seconds and milliseconds of {@code timeInMs} (default time zone). */
    static long roundDownMinute(long timeInMs) {
        Calendar date = Calendar.getInstance();
        date.setTimeInMillis(timeInMs);
        date.clear(Calendar.SECOND);
        date.clear(Calendar.MILLISECOND);
        return date.getTimeInMillis();
    }
}
